import bpy

# ---------- Helpers (prefs/colors) ----------
def _addon_key():
    """Return the key this add-on is registered under in ``preferences.addons``.

    ``__package__`` is used when it is non-empty; otherwise the first
    dot-separated component of ``__name__`` is used.
    """
    package = __package__
    if package:
        return package
    head, _, _ = __name__.partition('.')
    return head

def _get_addon_prefs():
    """Return this add-on's preferences object, or ``None`` if it cannot be read.

    Any failure during the lookup (e.g. no add-on entry under the expected
    key) is deliberately treated as "no preferences" so callers can fall
    back to built-in defaults.
    """
    try:
        addons = bpy.context.preferences.addons
        entry = addons[_addon_key()]
        return entry.preferences
    except Exception:
        return None

def get_mode_color(mode: str):
    """Return the node color for *mode*, preferring the add-on preferences.

    The matching ``color_*`` preference is converted to a tuple and returned
    when preferences are available and readable; otherwise a built-in RGB
    default is used. Unknown modes get a neutral grey.
    """
    pref_attr = {
        'Switch': 'color_switch',
        'Mix': 'color_mix',
        'Random': 'color_random',
        'ColorMixer': 'color_colormixer',
        'Master': 'color_master',
    }
    fallback = {
        'Switch': (0.086, 0.157, 0.247),
        'Mix': (0.086, 0.247, 0.157),
        'Random': (0.247, 0.157, 0.086),
        'ColorMixer': (0.247, 0.086, 0.157),
        'Master': (0.157, 0.086, 0.247),
    }
    prefs = _get_addon_prefs()
    attr_name = pref_attr.get(mode)
    if prefs and attr_name is not None:
        try:
            return tuple(getattr(prefs, attr_name))
        except Exception:
            # Missing/invalid preference: fall through to the built-in default.
            pass
    return fallback.get(mode, (0.2, 0.2, 0.2))

# ---------- Builders ----------
def create_shader_switcher_node_group(shader_count=3, mode='Switch', color_count=3):
    """Create and populate a new, uniquely named node group for *mode*.

    Args:
        shader_count: Number of shader inputs for 'Switch', 'Mix', 'Random'
            and 'Master' (must be 1..50; 'Random' is raised to at least 2).
        mode: One of 'Switch', 'Mix', 'Random', 'ColorMixer', 'Master'.
        color_count: Number of color inputs for 'ColorMixer' (must be 1..50).

    Returns:
        The new ``ShaderNodeTree`` group, tagged with ``input_count`` and
        ``mode`` custom properties.

    Raises:
        RuntimeError: If *mode* is unknown or the relevant count is out of
            range. Any exception raised while building the group is
            re-raised after the half-built group has been removed.
    """
    if mode in ('Switch', 'Mix', 'Random', 'Master'):
        if mode == 'Random' and shader_count < 2:
            shader_count = 2
        if shader_count < 1 or shader_count > 50:
            raise RuntimeError(f"Invalid shader count: {shader_count}. Must be between 1 and 50.")
        input_count_to_use = shader_count
    elif mode == 'ColorMixer':
        if color_count < 1 or color_count > 50:
            raise RuntimeError(f"Invalid color count: {color_count}. Must be between 1 and 50.")
        input_count_to_use = color_count
    else:
        raise RuntimeError("Invalid mode: must be 'Switch', 'Mix', 'Random', 'ColorMixer', or 'Master'.")

    if mode in ('Switch', 'Mix', 'Random'):
        initial_display_label = f"Shader {mode.capitalize()} ({input_count_to_use})"
    elif mode == 'ColorMixer':
        initial_display_label = f"Color Mixer ({input_count_to_use})"
    else:
        initial_display_label = f"Master Material ({input_count_to_use})"

    # Choose a free name ourselves (Blender-style ".NNN" suffix).
    group_name = initial_display_label
    suffix = 0
    while group_name in bpy.data.node_groups:
        suffix += 1
        group_name = f"{initial_display_label}.{suffix:03d}"

    builders = {
        'Switch': create_index_switcher,
        'Mix': create_layer_mixer,
        'Random': create_random_switcher,
        'ColorMixer': create_color_mixer,
        'Master': create_master_material_router,
    }

    group = bpy.data.node_groups.new(group_name, 'ShaderNodeTree')
    group["input_count"] = input_count_to_use
    group["mode"] = mode
    try:
        return builders[mode](group, input_count_to_use)
    except Exception:
        # Don't leave an empty, unused node group behind in bpy.data.
        try:
            bpy.data.node_groups.remove(group)
        except Exception:
            pass  # cleanup is best-effort; the builder's error is what matters
        raise

def create_random_switcher(group, shader_count):
    """Rebuild *group* as a per-object random shader picker.

    The slot is ``k = floor(fract(ObjectInfo.Random + Seed) * N)``, so each of
    the N inputs is chosen for an equal share of objects. Index k selects
    "Input k+1"; any value that matches no slot falls back to "Input 1".

    Note: this previously used ROUND on ``[0, N - 0.001]``. That produced
    indices 0..N with half-width end buckets, which skewed the distribution
    (Input 1 got ~1.5/N of objects, Input 2 only ~0.5/N).

    Args:
        group: Shader node group to clear and rebuild in place.
        shader_count: Number of shader inputs (expected >= 1).

    Returns:
        The same *group*, rebuilt.
    """
    group.nodes.clear()
    group.interface.clear()
    iface = group.interface

    seed_socket = iface.new_socket(name="Seed", socket_type='NodeSocketFloat', in_out='INPUT')
    seed_socket.default_value = 0.0
    seed_socket.min_value = 0.0
    seed_socket.max_value = 1000.0
    seed_socket.description = "Adjust to change the random pattern per object"

    for i in range(1, shader_count + 1):
        s = iface.new_socket(name=f"Input {i}", socket_type='NodeSocketShader', in_out='INPUT')
        s.description = f"Shader input #{i}."
    iface.new_socket(name="Output", socket_type='NodeSocketShader', in_out='OUTPUT')

    nodes = group.nodes
    links = group.links
    group_input = nodes.new("NodeGroupInput"); group_input.location = (-800, 0)
    group_output = nodes.new("NodeGroupOutput"); group_output.location = (600, 0)
    obj_info = nodes.new("ShaderNodeObjectInfo"); obj_info.location = (-600, 200)

    add_seed = nodes.new("ShaderNodeMath"); add_seed.operation = 'ADD'; add_seed.location = (-400, 200)
    links.new(obj_info.outputs["Random"], add_seed.inputs[0])
    links.new(group_input.outputs["Seed"], add_seed.inputs[1])

    # Wrap (random + seed) back into [0, 1).
    math_fract = nodes.new("ShaderNodeMath"); math_fract.operation = 'FRACT'; math_fract.location = (-200, 200)
    links.new(add_seed.outputs[0], math_fract.inputs[0])

    # Scale [0, 1) to [0, N); FLOOR then gives every slot 0..N-1 an equal-width bucket.
    map_range = nodes.new("ShaderNodeMapRange"); map_range.location = (0, 200)
    map_range.inputs['From Min'].default_value = 0.0
    map_range.inputs['From Max'].default_value = 1.0
    map_range.inputs['To Min'].default_value = 0.0
    map_range.inputs['To Max'].default_value = float(shader_count)
    links.new(math_fract.outputs[0], map_range.inputs[0])

    math_floor = nodes.new("ShaderNodeMath"); math_floor.operation = 'FLOOR'; math_floor.location = (200, 200)
    links.new(map_range.outputs[0], math_floor.inputs[0])

    # Start from Input 1 (slot 0 / fallback); each later mix swaps in Input k+1 when slot == k.
    prev_shader = group_input.outputs["Input 1"]
    for k in range(1, shader_count):
        y = -(k - 1) * 150 - 200
        compare = nodes.new("ShaderNodeMath"); compare.operation = 'COMPARE'
        compare.inputs[1].default_value = float(k)
        # FLOOR output is integral, so 0.5 matches exactly one slot.
        compare.inputs[2].default_value = 0.5
        compare.location = (200, y)
        links.new(math_floor.outputs[0], compare.inputs[0])

        mix = nodes.new("ShaderNodeMixShader"); mix.location = (400, y)
        links.new(compare.outputs[0], mix.inputs[0])
        links.new(prev_shader, mix.inputs[1])
        links.new(group_input.outputs[f"Input {k + 1}"], mix.inputs[2])
        prev_shader = mix.outputs[0]

    links.new(prev_shader, group_output.inputs[0])
    return group

def create_index_switcher(group, shader_count):
    """Rebuild *group* as an index-driven shader switcher.

    An integer "Index" input (1..N) chooses among "Input 1".."Input N".
    Index k selects Input k; a value matching no slot yields Input 1.

    Args:
        group: Shader node group to clear and rebuild in place.
        shader_count: Number of shader inputs.

    Returns:
        The same *group*, rebuilt.
    """
    group.nodes.clear()
    group.interface.clear()
    iface = group.interface

    index_socket = iface.new_socket(name="Index", socket_type='NodeSocketInt', in_out='INPUT')
    index_socket.description = "Controls which shader is active (1..N)."
    index_socket.default_value = 1
    index_socket.min_value = 1
    index_socket.max_value = max(shader_count, 1)

    for slot in range(1, shader_count + 1):
        shader_in = iface.new_socket(name=f"Input {slot}", socket_type='NodeSocketShader', in_out='INPUT')
        shader_in.description = f"Shader input #{slot}."
    output_socket = iface.new_socket(name="Output", socket_type='NodeSocketShader', in_out='OUTPUT')
    output_socket.description = "The active shader output based on Index."

    nodes = group.nodes
    links = group.links
    group_input = nodes.new("NodeGroupInput")
    group_input.location = (-800, 0)
    group_output = nodes.new("NodeGroupOutput")
    group_output.location = (800, 0)

    chained = None
    for slot in range(1, shader_count + 1):
        row_y = (1 - slot) * 200

        cmp_node = nodes.new("ShaderNodeMath")
        cmp_node.operation = 'COMPARE'
        cmp_node.inputs[1].default_value = float(slot)
        cmp_node.inputs[2].default_value = 0.01
        cmp_node.location = (-600, row_y)
        links.new(group_input.outputs["Index"], cmp_node.inputs[0])

        mixer = nodes.new("ShaderNodeMixShader")
        mixer.location = (-300, row_y)
        links.new(cmp_node.outputs[0], mixer.inputs[0])

        # The first mixer has no predecessor, so its "off" side is its own input.
        slot_socket = group_input.outputs[f"Input {slot}"]
        off_side = chained.outputs[0] if chained else slot_socket
        links.new(off_side, mixer.inputs[1])
        links.new(slot_socket, mixer.inputs[2])
        chained = mixer

    links.new(chained.outputs[0], group_output.inputs['Output'])
    return group

def create_layer_mixer(group, shader_count, mix_values=None):
    """Rebuild *group* as a bottom-to-top stack of shader layers.

    Layer 1 is mixed over a Transparent BSDF using "Factor 1". Each higher
    layer is then mixed over the running result with its own "Factor N".

    Args:
        group: Shader node group to clear and rebuild in place.
        shader_count: Number of layers ("Input N" / "Factor N" pairs).
        mix_values: Optional mapping from old socket names to captured state.
            Only numeric "Factor N" entries are reused, as the new factor
            defaults. Any other value is ignored and the default becomes 0.0;
            this includes the upstream NodeSocket recorded when a factor
            input was linked.

    Returns:
        The same *group*, rebuilt.
    """
    if mix_values is None:
        mix_values = {}
    group.nodes.clear(); group.interface.clear()

    # Interface items are created top layer first, so the highest layer is listed first.
    for i in reversed(range(shader_count)):
        layer_num = i + 1
        shader_socket = group.interface.new_socket(name=f"Input {layer_num}", socket_type='NodeSocketShader', in_out='INPUT')
        shader_socket.description = f"Layer {layer_num} shader input. Higher numbers appear on top of lower numbers."

        factor_name = f"Factor {layer_num}"
        mix_socket = group.interface.new_socket(name=factor_name, socket_type='NodeSocketFloat', in_out='INPUT')
        mix_socket.description = f"Controls blend amount of Layer {layer_num}. 0=transparent, 1=visible."
        # Callers pass their whole input snapshot. For a linked factor that snapshot
        # holds a NodeSocket, which can't be assigned to a float default.
        stored = mix_values.get(factor_name)
        mix_socket.default_value = float(stored) if isinstance(stored, (int, float)) else 0.0
        mix_socket.min_value = 0.0; mix_socket.max_value = 1.0; mix_socket.subtype = 'FACTOR'

    out = group.interface.new_socket(name="Output", socket_type='NodeSocketShader', in_out='OUTPUT')
    out.description = "The final blended result of all layers."

    nodes = group.nodes; links = group.links
    group_input = nodes.new("NodeGroupInput"); group_input.location = (-800, 0)
    group_output = nodes.new("NodeGroupOutput"); group_output.location = (800, 0)

    transparent = nodes.new("ShaderNodeBsdfTransparent"); transparent.location = (-600, 0)
    previous_output = transparent
    spacing = 300

    for i in range(shader_count):
        layer_num = i + 1
        y = -i * spacing
        mix = nodes.new("ShaderNodeMixShader"); mix.location = (-300, y)
        links.new(group_input.outputs[f"Factor {layer_num}"], mix.inputs[0])
        links.new(previous_output.outputs[0], mix.inputs[1])
        links.new(group_input.outputs[f"Input {layer_num}"], mix.inputs[2])
        previous_output = mix
    links.new(previous_output.outputs[0], group_output.inputs['Output'])
    return group

def create_color_mixer(group, color_count, input_state_map=None):
    """Rebuild *group* as a layered color mixer.

    A base MixRGB node outputs (0, 0, 0, 0). Each "Color N" is then mixed
    over the running result with its "Factor N", bottom (1) to top (N).

    Args:
        group: Shader node group to clear and rebuild in place.
        color_count: Number of color layers.
        input_state_map: Optional captured state keyed by socket name.
            Non-socket "Color N" / "Factor N" values become the new defaults,
            and "BlendMode_N" entries restore each layer's blend type.

    Side effects:
        Stores each layer's internal mix node name in ``group["MixNode_N"]``.

    Returns:
        The same *group*, rebuilt.
    """
    state = {} if input_state_map is None else input_state_map

    def remembered(key, fallback):
        # Reuse a captured value unless it is a socket (the old input was linked).
        if key in state and not isinstance(state[key], bpy.types.NodeSocket):
            return state[key]
        return fallback

    group.nodes.clear()
    group.interface.clear()
    iface = group.interface

    # Create interface items top layer first so the highest layer is listed first.
    for layer_num in range(color_count, 0, -1):
        color_name = f"Color {layer_num}"
        factor_name = f"Factor {layer_num}"

        color_socket = iface.new_socket(name=color_name, socket_type='NodeSocketColor', in_out='INPUT')
        color_socket.description = f"Layer {layer_num} color input. Higher numbers appear on top of lower numbers."
        color_socket.default_value = remembered(color_name, (0.0, 0.0, 0.0, 1.0))

        factor_socket = iface.new_socket(name=factor_name, socket_type='NodeSocketFloat', in_out='INPUT')
        factor_socket.description = f"Controls blend amount of Layer {layer_num}. 0=transparent, 1=visible."
        factor_socket.default_value = remembered(factor_name, 1.0)
        factor_socket.min_value = 0.0
        factor_socket.max_value = 1.0
        factor_socket.subtype = 'FACTOR'

    output_socket = iface.new_socket(name="Output", socket_type='NodeSocketColor', in_out='OUTPUT')
    output_socket.description = "The final blended color output."

    nodes = group.nodes
    links = group.links
    group_input = nodes.new("NodeGroupInput")
    group_input.location = (-800, 0)
    group_output = nodes.new("NodeGroupOutput")
    group_output.location = (800, 0)

    base = nodes.new("ShaderNodeMixRGB")
    base.blend_type = 'MIX'
    base.inputs['Color1'].default_value = (0.0, 0.0, 0.0, 0.0)
    base.inputs['Color2'].default_value = (0.0, 0.0, 0.0, 0.0)
    base.inputs['Fac'].default_value = 1.0
    base.location = (-600, 0)

    running = base.outputs['Color']
    for layer_num in range(1, color_count + 1):
        mix = nodes.new("ShaderNodeMixRGB")
        mix.location = (-300, -(layer_num - 1) * 300)

        blend_key = f"BlendMode_{layer_num}"
        if blend_key in state:
            mix.blend_type = state[blend_key]

        group[f"MixNode_{layer_num}"] = mix.name
        links.new(group_input.outputs[f"Factor {layer_num}"], mix.inputs['Fac'])
        links.new(running, mix.inputs['Color1'])
        links.new(group_input.outputs[f"Color {layer_num}"], mix.inputs['Color2'])
        running = mix.outputs['Color']

    links.new(running, group_output.inputs['Output'])
    return group

def create_master_material_router(group, shader_count, input_state_map=None):
    """Rebuild *group* as a router that selects a shader by the object's Pass Index.

    Each "Input N" slot has an integer branch ID. Objects whose Object Index
    (from an Object Info node) equals that ID get the slot's shader; all
    others get "Default Shader". If several slots share an ID, the later slot
    wins because it is mixed in last.

    Branch IDs are resolved per slot from, in order:
      1. a numeric ``id_N`` entry in *input_state_map*;
      2. an ``id_N`` custom property already on *group*;
      3. N itself.
    The resolved IDs are written back to ``group["id_N"]``.

    Args:
        group: Shader node group to clear and rebuild in place.
        shader_count: Number of routed shader slots.
        input_state_map: Optional captured state (see resolution order above).

    Returns:
        The same *group*, rebuilt.
    """
    group.nodes.clear(); group.interface.clear()
    group["input_count"] = shader_count
    group["mode"] = 'Master'

    iface = group.interface
    default_socket = iface.new_socket(name="Default Shader", socket_type='NodeSocketShader', in_out='INPUT')
    default_socket.description = "Fallback shader used when no branch ID matches the object's Pass Index"

    ids = []
    for slot in range(1, shader_count + 1):
        key = f"id_{slot}"
        captured = input_state_map.get(key) if input_state_map else None
        if isinstance(captured, (int, float)):
            ids.append(int(captured))
        elif key in group:
            try:
                ids.append(int(group[key]))
            except (TypeError, ValueError, OverflowError):
                # Unusable stored value: fall back to the slot number.
                ids.append(slot)
        else:
            ids.append(slot)

    # Descriptions are final here; IDs don't change later in this function.
    for slot, branch_id in enumerate(ids, start=1):
        s = iface.new_socket(name=f"Input {slot}", socket_type='NodeSocketShader', in_out='INPUT')
        s.description = f"Shader for objects whose Pass Index equals {branch_id}"
    out_sock = iface.new_socket(name="Output", socket_type='NodeSocketShader', in_out='OUTPUT')
    out_sock.description = "Final shader after routing by Object Pass Index"

    nodes = group.nodes; links = group.links
    inp = nodes.new("NodeGroupInput"); inp.location = (-900, 0)
    out = nodes.new("NodeGroupOutput"); out.location = (700, 0)
    obj_info = nodes.new("ShaderNodeObjectInfo"); obj_info.location = (-700, 300)

    prev_shader = inp.outputs["Default Shader"]
    for slot, branch_id in enumerate(ids, start=1):
        y = -(slot - 1) * 220
        cmp = nodes.new("ShaderNodeMath"); cmp.operation = 'COMPARE'; cmp.location = (-500, y)
        # Object Index is integral, so an epsilon of 0.5 matches exactly one ID.
        cmp.inputs[1].default_value = float(branch_id); cmp.inputs[2].default_value = 0.5
        links.new(obj_info.outputs["Object Index"], cmp.inputs[0])

        mix = nodes.new("ShaderNodeMixShader"); mix.location = (-250, y)
        links.new(cmp.outputs[0], mix.inputs[0])
        links.new(prev_shader, mix.inputs[1])
        links.new(inp.outputs[f"Input {slot}"], mix.inputs[2])
        prev_shader = mix.outputs[0]
        group[f"id_{slot}"] = branch_id

    links.new(prev_shader, out.inputs["Output"])
    return group

def update_shader_switcher_node_group(group, new_input_count, mode, input_state_map=None):
    """Rebuild an existing switcher *group* in place for *mode* and input count.

    ``group["mode"]`` and ``group["input_count"]`` are updated first. The
    mode-specific builder then runs. For 'Switch' and 'Random', an unlinked
    "Index" / "Seed" value captured in *input_state_map* is copied back onto
    the rebuilt interface socket; this is best-effort and errors are ignored.

    Returns:
        The rebuilt group returned by the builder.

    Raises:
        RuntimeError: If *group* is falsy or *mode* is unknown.
    """
    if not group:
        raise RuntimeError("No node group provided for update")
    group["mode"] = mode
    group["input_count"] = new_input_count

    def _restore_scalar(tree, socket_name):
        # Best-effort: reapply a captured, non-socket value as the interface default.
        try:
            items = tree.interface.items_tree
            if socket_name in items and input_state_map and socket_name in input_state_map:
                value = input_state_map[socket_name]
                if not isinstance(value, bpy.types.NodeSocket):
                    items[socket_name].default_value = value
        except Exception:
            pass

    if mode == 'Switch':
        rebuilt = create_index_switcher(group, new_input_count)
        _restore_scalar(rebuilt, "Index")
        return rebuilt
    if mode == 'Mix':
        return create_layer_mixer(group, new_input_count, input_state_map)
    if mode == 'Random':
        rebuilt = create_random_switcher(group, new_input_count)
        _restore_scalar(rebuilt, "Seed")
        return rebuilt
    if mode == 'ColorMixer':
        return create_color_mixer(group, new_input_count, input_state_map)
    if mode == 'Master':
        return create_master_material_router(group, new_input_count, input_state_map)
    raise RuntimeError(f"Unknown mode: {mode}")

# ---------- Callback & Operators ----------
def update_shader_node_group_mode(self, context):
    """Property-update callback: rebuild a switcher node group for its new mode.

    ``self`` is a property owner whose ``id_data`` is the node group and
    whose ``mode`` holds the newly chosen mode. The callback:

    1. Locates the first selected group node in the current node editor's
       edit tree that uses this node group.
    2. Captures that node's links and unlinked input values.
    3. Rebuilds the node group for the new mode with the same input count.
    4. Re-applies the captured state. Sockets whose names no longer exist
       are skipped.

    If no such node is found, a warning is printed and nothing is changed.
    """
    node_group = self.id_data  # NodeTree
    active_node = None
    # Only a *selected* group node in the edited tree that uses this group qualifies.
    if context.space_data and context.space_data.type == 'NODE_EDITOR' and context.space_data.edit_tree:
        for node in context.space_data.edit_tree.nodes:
            if node.type == 'GROUP' and node.node_tree == node_group and node.select:
                active_node = node; break
    if not active_node:
        print(f"Warning: Node group '{node_group.name}' property changed but no active node instance found.")
        return

    # Give this node its own copy so other users of a shared group keep their structure.
    if node_group.users > 1:
        original_group_name = node_group.name
        new_node_group = node_group.copy()
        active_node.node_tree = new_node_group
        node_group = new_node_group
        print(f"Duplicated node group '{original_group_name}' made single user as '{node_group.name}'.")

    old_mode = node_group.get("mode")
    current_input_count = node_group.get("input_count", 3)

    # Snapshot of the node's inputs, keyed by socket name: the upstream socket
    # for linked inputs, the current default value for unlinked ones.
    connected_input_sources = {}
    # (output socket name, downstream socket) pairs to restore after the rebuild.
    output_connections = []

    for input_socket in active_node.inputs:
        if input_socket.is_linked:
            for link in input_socket.links:
                connected_input_sources[input_socket.name] = link.from_socket
        else:
            if hasattr(input_socket, 'default_value'):
                # Copy RGBA values into a plain tuple (presumably so the snapshot
                # doesn't reference the old socket's array after the rebuild).
                if input_socket.type == 'RGBA':
                    connected_input_sources[input_socket.name] = tuple(input_socket.default_value)
                else:
                    connected_input_sources[input_socket.name] = input_socket.default_value

    # Carry over Master branch IDs (id_N custom properties) as ints.
    if old_mode == 'Master':
        for i in range(1, current_input_count + 1):
            key = f"id_{i}"
            if key in node_group:
                try:
                    connected_input_sources[key] = int(node_group[key])
                except Exception:
                    connected_input_sources[key] = i

    # Carry over each ColorMixer layer's blend type, read from its internal mix node.
    if old_mode == 'ColorMixer':
        for i in range(1, current_input_count + 1):
            mix_node_name_prop = f"MixNode_{i}"
            if mix_node_name_prop in node_group and node_group.nodes.get(node_group[mix_node_name_prop]):
                internal_mix_node = node_group.nodes[node_group[mix_node_name_prop]]
                if hasattr(internal_mix_node, 'blend_type'):
                    connected_input_sources[f"BlendMode_{i}"] = internal_mix_node.blend_type
                # Remove any BlendMode_N custom property stored on the group itself
                # (presumably left over from an older version — confirm).
                if f"BlendMode_{i}" in node_group:
                    del node_group[f"BlendMode_{i}"]

    for output_socket in active_node.outputs:
        if output_socket.is_linked:
            for link in output_socket.links:
                output_connections.append((output_socket.name, link.to_socket))

    new_mode = self.mode
    update_shader_switcher_node_group(node_group, current_input_count, new_mode, connected_input_sources)

    if new_mode in ['Switch', 'Mix', 'Random']:
        display_label = f"Shader {new_mode.capitalize()} ({current_input_count})"
    elif new_mode == 'ColorMixer':
        display_label = f"Color Mixer ({current_input_count})"
    else:
        display_label = f"Master Material ({current_input_count})"

    active_node.label = display_label
    active_node.use_custom_color = True
    active_node.color = get_mode_color(new_mode)
    node_group.name = display_label

    node_tree_in_editor = context.space_data.edit_tree

    # Drop every link touching the node, then rebuild them from the snapshot
    # against the node's (possibly renamed/reordered) sockets.
    for old_link in list(node_tree_in_editor.links):
        if old_link.from_node == active_node or old_link.to_node == active_node:
            node_tree_in_editor.links.remove(old_link)

    for new_input_socket in active_node.inputs:
        if new_input_socket.name in connected_input_sources:
            source = connected_input_sources[new_input_socket.name]
            if isinstance(source, bpy.types.NodeSocket):
                try: node_tree_in_editor.links.new(source, new_input_socket)
                except Exception as e: print(f"Failed to re-link input {new_input_socket.name}: {e}")
            else:
                # Unlinked value: restore it as the socket default (ignored if the socket has none).
                try: new_input_socket.default_value = source
                except AttributeError: pass

    for identifier, to_socket in output_connections:
        if identifier in active_node.outputs:
            try: node_tree_in_editor.links.new(active_node.outputs[identifier], to_socket)
            except Exception as e: print(f"Failed to re-link output {identifier}: {e}")

    print(f"Node Group '{node_group.name}' mode changed to: {new_mode}")

class NODE_OT_add_shader_switcher(bpy.types.Operator):
    """Add a Shader Switcher/Color Mixer/Master Material node group."""
    bl_idname = "node.add_shader_switcher"
    bl_label = "Add Shader/Color Switcher"
    bl_description = "Adds a Switcher, Mixer, Randomizer, Color Mixer, or Master Material node"
    bl_options = {'REGISTER', 'UNDO'}

    node_type: bpy.props.EnumProperty(
        name="Type",
        description="Choose the type of node to add",
        items=[
            ('Switch', "Shader Switcher", "Add a Shader Switcher node"),
            ('Mix', "Shader Mixer", "Add a Shader Mixer node"),
            ('Random', "Shader Randomizer", "Add a Shader Randomizer node"),
            ('ColorMixer', "Color Mixer", "Add a Color Mixer node"),
            ('Master', "Master Material", "Add a Master Material router node"),
        ],
        default='Switch',
        options={'HIDDEN'}
    )
    initial_count: bpy.props.IntProperty(
        name="Initial Inputs",
        description="How many inputs to start with",
        default=3, min=1, max=50, options={'HIDDEN'}
    )

    @classmethod
    def poll(cls, context):
        """Enable only in a Shader node editor that currently has a tree to edit."""
        space = context.space_data
        # edit_tree is None when the editor shows no tree; execute() would fail on it.
        return bool(space and space.type == 'NODE_EDITOR'
                    and space.tree_type == 'ShaderNodeTree'
                    and space.edit_tree is not None)

    def execute(self, context):
        """Create the node group, add a group node for it at the cursor, and select it."""
        try:
            group = create_shader_switcher_node_group(
                shader_count=self.initial_count,
                color_count=self.initial_count,
                mode=self.node_type
            )
            tree = context.space_data.edit_tree
            node = tree.nodes.new("ShaderNodeGroup")
            node.node_tree = group
            node.location = context.space_data.cursor_location

            # Label and color the node here instead of relying only on a
            # property-update callback. The builder may have adjusted the count
            # (e.g. Random needs >= 2 inputs), so read it back from the group.
            count = group.get("input_count", self.initial_count)
            if self.node_type in ('Switch', 'Mix', 'Random'):
                node.label = f"Shader {self.node_type.capitalize()} ({count})"
            elif self.node_type == 'ColorMixer':
                node.label = f"Color Mixer ({count})"
            else:
                node.label = f"Master Material ({count})"
            node.use_custom_color = True
            node.color = get_mode_color(self.node_type)

            tree.nodes.active = node
            node.select = True
            node.node_tree.shader_switcher_props.mode = self.node_type

            self.report({'INFO'}, f"Created {node.label} node.")
            return {'FINISHED'}
        except Exception as e:
            self.report({'ERROR'}, str(e))
            return {'CANCELLED'}

class NODE_OT_add_shader_input(bpy.types.Operator):
    """Add a new input to the selected node group."""
    bl_idname = "node.add_shader_switcher_input"
    bl_label = "Add Input"
    bl_description = "Adds a new shader/color input to the selected node group"
    bl_options = {'REGISTER', 'UNDO'}

    @staticmethod
    def _switcher_node(tree):
        """Return *tree*'s active node if it is a group node tagged with "mode", else None."""
        if tree is None:
            return None
        node = tree.nodes.active
        # A group node may have no node_tree assigned; `"mode" in None` would raise.
        if node and node.bl_idname == 'ShaderNodeGroup' and node.node_tree and "mode" in node.node_tree:
            return node
        return None

    @classmethod
    def poll(cls, context):
        """Enable for a switcher group node in the edited tree that has fewer than 50 inputs."""
        space = context.space_data
        if not space or space.type != 'NODE_EDITOR' or space.tree_type != 'ShaderNodeTree':
            return False
        # Inspect edit_tree, the tree execute() operates on (node_tree is only the
        # top-level tree and may be None).
        node = cls._switcher_node(space.edit_tree)
        return node is not None and node.node_tree.get("input_count", 0) < 50

    def execute(self, context):
        """Rebuild the active switcher group with one more input, preserving links and values."""
        try:
            node = self._switcher_node(context.space_data.edit_tree)
            if node is None:
                self.report({'ERROR'}, "Select one of the Shader Switcher nodes first.")
                return {'CANCELLED'}

            group = node.node_tree
            # Work on a private copy so other nodes sharing the group are unaffected.
            if group.users > 1:
                new_group = group.copy()
                node.node_tree = new_group
                group = new_group

            mode = group.get("mode", 'Switch')
            current_input_count = group.get("input_count", 0)
            if current_input_count >= 50:
                self.report({'WARNING'}, "Maximum limit of 50 inputs reached")
                return {'CANCELLED'}
            next_count = current_input_count + 1

            # Snapshot inputs by name: upstream socket if linked, else the default value.
            input_state_map = {}
            for input_socket in node.inputs:
                if input_socket.is_linked:
                    for link in input_socket.links:
                        input_state_map[input_socket.name] = link.from_socket
                else:
                    if hasattr(input_socket, 'default_value'):
                        if input_socket.type == 'RGBA':
                            input_state_map[input_socket.name] = tuple(input_socket.default_value)
                        else:
                            input_state_map[input_socket.name] = input_socket.default_value

            # ColorMixer: remember each layer's blend type from its internal mix node.
            if mode == 'ColorMixer':
                for i in range(1, current_input_count + 1):
                    mix_node_name_prop = f"MixNode_{i}"
                    if mix_node_name_prop in group and group.nodes.get(group[mix_node_name_prop]):
                        internal_mix_node = group.nodes[group[mix_node_name_prop]]
                        if hasattr(internal_mix_node, 'blend_type'):
                            input_state_map[f"BlendMode_{i}"] = internal_mix_node.blend_type

            # Master: remember each slot's branch ID.
            if mode == 'Master':
                for i in range(1, current_input_count + 1):
                    key = f"id_{i}"
                    if key in group:
                        try:
                            input_state_map[key] = int(group[key])
                        except Exception:
                            input_state_map[key] = i

            output_connections = []
            for socket in node.outputs:
                if socket.is_linked:
                    for link in socket.links:
                        output_connections.append((socket.name, link.to_socket))

            update_shader_switcher_node_group(group, next_count, mode, input_state_map)

            if mode in ['Switch', 'Mix', 'Random']:
                display_label = f"Shader {mode.capitalize()} ({next_count})"
            elif mode == 'ColorMixer':
                display_label = f"Color Mixer ({next_count})"
            else:
                display_label = f"Master Material ({next_count})"

            node.use_custom_color = True
            node.color = get_mode_color(mode)
            node.label = display_label
            node.node_tree.name = display_label

            if node.node_tree and context.space_data.edit_tree:
                node_tree_in_editor = context.space_data.edit_tree

                # Drop all links touching the node and rebuild them from the snapshot.
                for old_link in list(node_tree_in_editor.links):
                    if old_link.from_node == node or old_link.to_node == node:
                        node_tree_in_editor.links.remove(old_link)

                for new_input_socket in node.inputs:
                    if new_input_socket.name in input_state_map:
                        source = input_state_map[new_input_socket.name]
                        if isinstance(source, bpy.types.NodeSocket):
                            try:
                                node_tree_in_editor.links.new(source, new_input_socket)
                            except Exception as e:
                                print(f"Failed to re-link input {new_input_socket.name}: {e}")
                        else:
                            try:
                                new_input_socket.default_value = source
                            except AttributeError:
                                pass

                for identifier, to_socket in output_connections:
                    if identifier in node.outputs:
                        try:
                            node_tree_in_editor.links.new(node.outputs[identifier], to_socket)
                        except Exception as e:
                            print(f"Failed to re-link output {identifier}: {e}")

            self.report({'INFO'}, f"Added input. Now {display_label}")
            return {'FINISHED'}
        except Exception as e:
            self.report({'ERROR'}, f"Failed to add input: {str(e)}")
            return {'CANCELLED'}

class NODE_OT_remove_shader_input(bpy.types.Operator):
    """Remove a specific input from the selected node group.

    The input to drop is given by the 1-based ``input_index_to_remove``.
    Inputs after it shift down by one; their links, default values and (for
    Color Mixer groups) blend modes are carried to their new positions, and
    output links are restored by socket name once the group is rebuilt.
    """
    bl_idname = "node.remove_shader_input"
    bl_label = "Remove Input"
    bl_description = "Remove the selected shader/color input from the node group"
    bl_options = {'REGISTER', 'UNDO'}

    input_index_to_remove: bpy.props.IntProperty(
        name="Input Index to Remove",
        description="1-based index of the input to remove."
    )

    # ----- helpers -----
    @staticmethod
    def _active_switcher_node(space):
        """Return the active switcher group node of the edited tree, or None.

        ``poll`` and ``execute`` both use this so they inspect the same tree
        (``edit_tree``, the tree currently shown in the editor).
        """
        tree = getattr(space, "edit_tree", None)
        if tree is None:
            return None
        node = tree.nodes.active
        if node is None or node.bl_idname != 'ShaderNodeGroup':
            return None
        group = node.node_tree
        # A group node may have no tree assigned; "in None" would raise.
        if group is None or "mode" not in group:
            return None
        return node

    @staticmethod
    def _min_inputs(mode):
        """Minimum number of inputs a group of the given mode must keep."""
        return 2 if mode == 'Random' else 1

    @staticmethod
    def _snapshot_value(value):
        """Copy a socket default value out of Blender-owned memory.

        Array-like values (colors, vectors) become tuples so they remain valid
        after the group interface is rebuilt; scalars and strings are returned
        unchanged.
        """
        if isinstance(value, (str, bool, int, float)):
            return value
        try:
            return tuple(value)
        except TypeError:
            return value

    @staticmethod
    def _display_label(mode, count):
        """Node/tree label for a group of ``mode`` with ``count`` inputs."""
        if mode in ('Switch', 'Mix', 'Random'):
            return f"Shader {mode.capitalize()} ({count})"
        if mode == 'ColorMixer':
            return f"Color Mixer ({count})"
        return f"Master Material ({count})"

    @staticmethod
    def _restore_node_state(tree, node, input_state, output_connections):
        """Rewire ``node`` inside ``tree`` from previously captured state.

        ``input_state`` maps input socket names to either a source
        ``NodeSocket`` (re-linked) or a plain default value (assigned).
        ``output_connections`` is a list of ``(output_name, to_socket)``.
        """
        for old_link in list(tree.links):
            if old_link.from_node == node or old_link.to_node == node:
                tree.links.remove(old_link)

        for new_input_socket in node.inputs:
            if new_input_socket.name not in input_state:
                continue
            source = input_state[new_input_socket.name]
            if isinstance(source, bpy.types.NodeSocket):
                try:
                    tree.links.new(source, new_input_socket)
                except Exception as e:
                    print(f"Failed to re-link input {new_input_socket.name}: {e}")
            else:
                try:
                    new_input_socket.default_value = source
                except (AttributeError, TypeError, ValueError):
                    # Best effort: socket has no default or the stored value
                    # no longer fits its type/size; keep the socket's default.
                    pass

        for identifier, to_socket in output_connections:
            if identifier in node.outputs:
                try:
                    tree.links.new(node.outputs[identifier], to_socket)
                except Exception as e:
                    print(f"Failed to re-link output {identifier}: {e}")

    # ----- operator API -----
    @classmethod
    def poll(cls, context):
        space = context.space_data
        if not space or space.type != 'NODE_EDITOR' or space.tree_type != 'ShaderNodeTree':
            return False
        node = cls._active_switcher_node(space)
        if node is None:
            return False

        group = node.node_tree
        mode = group.get("mode", 'Switch')
        return group.get("input_count", 0) > cls._min_inputs(mode)

    def execute(self, context):
        try:
            space = context.space_data
            node = self._active_switcher_node(space)
            if node is None:
                self.report({'ERROR'}, "Select one of the Shader Switcher nodes first.")
                return {'CANCELLED'}

            group = node.node_tree
            mode = group.get("mode", 'Switch')
            current_input_count = group.get("input_count", 0)
            min_inputs_for_mode = self._min_inputs(mode)
            if current_input_count <= min_inputs_for_mode:
                self.report({'WARNING'}, f"Cannot remove input: minimum of {min_inputs_for_mode} for '{mode}'.")
                return {'CANCELLED'}

            index_to_remove = self.input_index_to_remove
            if not 1 <= index_to_remove <= current_input_count:
                # Without this check an out-of-range index skips nothing and
                # the last input is silently dropped instead.
                self.report(
                    {'ERROR'},
                    f"Cannot remove input {index_to_remove}: index must be between 1 and {current_input_count}.",
                )
                return {'CANCELLED'}

            # Make the group single-user only once the request is known to be
            # valid, so a cancelled call leaves no stray copy behind.
            if group.users > 1:
                group = group.copy()
                node.node_tree = group

            # Capture each input's link source or a detached default value.
            input_state_map_raw = {}
            for input_socket in node.inputs:
                if input_socket.is_linked:
                    for link in input_socket.links:
                        input_state_map_raw[input_socket.name] = link.from_socket
                elif hasattr(input_socket, 'default_value'):
                    input_state_map_raw[input_socket.name] = self._snapshot_value(input_socket.default_value)

            if mode == 'ColorMixer':
                # Blend modes live on internal nodes whose names are stored in
                # the group's "MixNode_<i>" custom properties.
                for i in range(1, current_input_count + 1):
                    prop_name = f"MixNode_{i}"
                    mix_node = group.nodes.get(group[prop_name]) if prop_name in group else None
                    if mix_node is not None and hasattr(mix_node, 'blend_type'):
                        input_state_map_raw[f"BlendMode_{i}"] = mix_node.blend_type

            new_input_count = current_input_count - 1
            prefix = "Input" if mode in ['Switch', 'Mix', 'Random', 'Master'] else "Color"

            # Map every surviving numbered name to its shifted-down index.
            kept_indices = [i for i in range(1, current_input_count + 1) if i != index_to_remove]
            name_remapping = {}
            for new_i, old_i in enumerate(kept_indices, start=1):
                name_remapping[f"{prefix} {old_i}"] = f"{prefix} {new_i}"
                if mode in ('Mix', 'ColorMixer'):
                    name_remapping[f"Factor {old_i}"] = f"Factor {new_i}"
                if mode == 'ColorMixer':
                    name_remapping[f"BlendMode_{old_i}"] = f"BlendMode_{new_i}"
                if mode == 'Master':
                    name_remapping[f"id_{old_i}"] = f"id_{new_i}"

            # Entries of the removed input are absent from name_remapping and
            # are therefore dropped here.
            input_state_map_reindexed = {
                name_remapping[old_name]: value
                for old_name, value in input_state_map_raw.items()
                if old_name in name_remapping
            }

            # Un-numbered control inputs (names not ending in a digit, e.g. a
            # switch index input -- actual names come from
            # update_shader_switcher_node_group) are not renumbered; keep their
            # link/value too, since every link on the node is cleared below.
            restore_state = {
                name: value
                for name, value in input_state_map_raw.items()
                if not name[-1:].isdigit()
            }
            restore_state.update(input_state_map_reindexed)

            output_connections = []
            for socket in node.outputs:
                if socket.is_linked:
                    for link in socket.links:
                        output_connections.append((socket.name, link.to_socket))

            update_shader_switcher_node_group(group, new_input_count, mode, input_state_map_reindexed)

            display_label = self._display_label(mode, new_input_count)
            node.label = display_label
            node.node_tree.name = display_label
            node.use_custom_color = True
            node.color = get_mode_color(mode)

            tree = getattr(space, "edit_tree", None)
            if node.node_tree and tree:
                self._restore_node_state(tree, node, restore_state, output_connections)

            self.report({'INFO'}, f"Removed input. Now {display_label}")
            return {'FINISHED'}
        except Exception as e:
            self.report({'ERROR'}, f"Failed to remove input: {str(e)}")
            return {'CANCELLED'}

class OBJECT_OT_assign_selection_to_branch(bpy.types.Operator):
    """Set the Pass Index of selected objects."""
    bl_idname = "object.assign_selection_to_branch"
    bl_label = "Assign Selection to Branch"
    bl_description = "Set the Pass Index of selected objects to the chosen branch ID"
    bl_options = {'REGISTER', 'UNDO'}

    pass_index: bpy.props.IntProperty(name="Pass Index", min=0, description="Object Pass Index to assign")

    @classmethod
    def poll(cls, context):
        # selected_objects is a (possibly empty) list, so an "is not None"
        # test is always true; require an actual selection instead.
        return bool(getattr(context, "selected_objects", None))

    def execute(self, context):
        """Assign ``pass_index`` to every selected object and report the result.

        Objects whose pass index cannot be written (e.g. read-only data) are
        skipped and counted rather than reported as updated.
        """
        selected = list(getattr(context, "selected_objects", None) or [])
        value = int(self.pass_index)

        updated = 0
        skipped = 0
        for obj in selected:
            try:
                obj.pass_index = value
            except Exception:
                skipped += 1
            else:
                updated += 1

        if skipped:
            self.report(
                {'WARNING'},
                f"Set Pass Index = {value} on {updated} object(s); skipped {skipped} object(s) that could not be changed.",
            )
        else:
            self.report({'INFO'}, f"Set Pass Index = {value} on {updated} object(s).")
        return {'FINISHED'}

# ---------- Properties ----------
class ShaderSwitcherProps(bpy.types.PropertyGroup):
    # Node-type selector. Changing it triggers update_shader_node_group_mode,
    # which must already be defined when this class body executes.
    mode: bpy.props.EnumProperty(
        name="Node Type",
        default='Switch',
        update=update_shader_node_group_mode,
        items=[
            ('Switch', "Shader Switcher", "Switch between shaders using an index value"),
            ('Mix', "Shader Mixer", "Mix shaders with opacity control"),
            ('Random', "Shader Randomizer", "Switch shaders using a random pattern"),
            ('ColorMixer', "Color Mixer", "Blend colors with factors and blend modes"),
            ('Master', "Master Material", "Route objects to shader branches using Object Pass Index"),
        ],
        # Multi-line tooltip: one heading line, then one bullet per mode.
        description="\n".join((
            "Choose how content is combined:",
            "• Switch: Switch between shaders using a number (1–N)",
            "• Mix: Mix shaders with opacity control",
            "• Random: Switch shaders using a random pattern",
            "• Color Mixer: Blend colors with factors and blend modes",
            "• Master Material: Route objects via Object Pass Index",
        )),
    )

class ShaderSwitcherAddonProps(bpy.types.PropertyGroup):
    """Settings used when creating a new switcher/mixer node group."""

    creation_type: bpy.props.EnumProperty(
        name="Type",
        description="Kind of node group to add",
        items=[
            ('Switch', "Shader Switcher", "Add a Shader Switcher node"),
            ('Mix', "Shader Mixer", "Add a Shader Mixer node"),
            ('Random', "Shader Randomizer", "Add a Shader Randomizer node"),
            ('ColorMixer', "Color Mixer", "Blend colors with factors and blend modes"),
            ('Master', "Master Material", "Route objects to shader branches using Object Pass Index"),
        ],
        default='ColorMixer'
    )
    # 1-50 mirrors the range accepted by create_shader_switcher_node_group,
    # which also raises Random-mode counts below 2 up to 2.
    initial_input_count: bpy.props.IntProperty(
        name="Initial Inputs",
        description="Number of inputs the new node group starts with",
        default=3,
        min=1,
        max=50,
    )

# ---------- Cleanup ----------
class NODE_OT_remove_unused_switchers(bpy.types.Operator):
    """Remove all unused Shader Switcher/Color Mixer/Master groups."""
    bl_idname = "node.remove_unused_shader_switchers"
    bl_label = "Remove Unused Groups"
    bl_description = "Remove all unused Shader Switcher/Color Mixer/Master Material node groups"
    bl_options = {'REGISTER', 'UNDO'}

    @staticmethod
    def _used_group_pointers():
        """Return ``as_pointer()`` of every node group referenced by a group node.

        Scans material, world and light node trees (regardless of
        ``use_nodes``, since a disabled tree still references its groups) and
        every node group. Computed once, before anything is removed, so no
        removed datablock is ever touched afterwards.
        """
        trees = [
            getattr(owner, "node_tree", None)
            for collection in (bpy.data.materials, bpy.data.worlds, bpy.data.lights)
            for owner in collection
        ]
        trees.extend(bpy.data.node_groups)

        used = set()
        for tree in trees:
            if tree is None:
                continue
            for node in tree.nodes:
                if node.type == 'GROUP' and node.node_tree is not None:
                    used.add(node.node_tree.as_pointer())
        return used

    @staticmethod
    def _trees_open_in_editors(screen):
        """Return ``as_pointer()`` of every tree shown in a Node Editor of ``screen``.

        The tree lives on the editor *space* (node_tree / edit_tree / path),
        not on the Area itself.
        """
        open_trees = set()
        if screen is None:
            return open_trees
        for area in screen.areas:
            if area.type != 'NODE_EDITOR':
                continue
            for space in area.spaces:
                if space.type != 'NODE_EDITOR':
                    continue
                candidates = [getattr(space, "node_tree", None), getattr(space, "edit_tree", None)]
                candidates.extend(getattr(p, "node_tree", None) for p in getattr(space, "path", ()))
                open_trees.update(t.as_pointer() for t in candidates if t is not None)
        return open_trees

    @classmethod
    def poll(cls, context):
        return True

    def execute(self, context):
        # Tuple (not set): "mode" may be an unhashable ID property value.
        switcher_modes = ('Switch', 'Mix', 'Random', 'ColorMixer', 'Master')
        used = self._used_group_pointers()
        open_in_editor = self._trees_open_in_editors(getattr(context, "screen", None))

        # Snapshot the candidates first; the collection shrinks as we remove.
        candidates = [group for group in bpy.data.node_groups if group.get("mode") in switcher_modes]

        # NOTE: single pass -- a switcher group used only inside another
        # unused group counts as used here and survives until the next run.
        removed = 0
        for group in candidates:
            pointer = group.as_pointer()
            if pointer in used:
                continue
            if pointer in open_in_editor:
                self.report({'WARNING'}, f"Skipping '{group.name}': currently open in a Node Editor.")
                continue
            bpy.data.node_groups.remove(group)
            removed += 1
        self.report({'INFO'}, f"Removed {removed} unused node group{'s' if removed != 1 else ''}")
        return {'FINISHED'}

# ---------- Registration ----------
# Registered in this order and unregistered in the reverse order.
classes = (
    OBJECT_OT_assign_selection_to_branch,
    NODE_OT_add_shader_input,
    NODE_OT_add_shader_switcher,
    NODE_OT_remove_shader_input,
    NODE_OT_remove_unused_switchers,
    ShaderSwitcherProps,
    ShaderSwitcherAddonProps,
)

def register():
    """Register every class in ``classes`` with Blender, first to last."""
    register_class = bpy.utils.register_class
    for klass in classes:
        register_class(klass)

def unregister():
    """Unregister every class in ``classes``, last to first."""
    unregister_class = bpy.utils.unregister_class
    for klass in classes[::-1]:
        unregister_class(klass)
